Build a Conditional GAN

Goals

In this notebook, you're going to make a conditional GAN in order to generate images of handwritten digits, conditioned on the digit to be generated (the class vector). This will let you choose which digit you want to generate.

You'll then do some exploration of the generated images to visualize what the noise and class vectors mean.

Learning Objectives

  1. Learn the technical difference between a conditional and unconditional GAN.
  2. Understand the distinction between the class and noise vector in a conditional GAN.

Getting Started

For this assignment, you will be using the MNIST dataset again, but there's nothing stopping you from applying this generator code to produce images of animals conditioned on the species or pictures of faces conditioned on facial characteristics.

Note that this assignment requires no changes to the architectures of the generator or discriminator, only changes to the data passed to both. The generator will no longer take z_dim as an argument, but input_dim instead, since you need to pass in both the noise and class vectors. Besides being a more accurate name, this also means that you can reuse the generator and discriminator code you have previously written, just with different parameters.

You will begin by importing the necessary libraries and building the generator and discriminator.

Packages and Visualization

In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for our testing purposes, please do not change!

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in a uniform grid.
    '''
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    if show:
        plt.show()
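
The `(image_tensor + 1) / 2` step in `show_tensor_images` undoes the normalization applied when the data is loaded: `transforms.Normalize((0.5,), (0.5,))` maps pixel values from [0, 1] to [-1, 1], which is also the range of the generator's Tanh output. As a minimal pure-Python sketch of the two mappings (the helper names `normalize` and `denormalize` are illustrative and not part of the notebook):

```python
def normalize(pixel, mean=0.5, std=0.5):
    # Same arithmetic as transforms.Normalize((0.5,), (0.5,)) on one value
    return (pixel - mean) / std

def denormalize(value):
    # Same arithmetic as (image_tensor + 1) / 2 in show_tensor_images
    return (value + 1) / 2

for pixel in (0.0, 0.25, 1.0):
    scaled = normalize(pixel)
    print(pixel, "->", scaled, "->", denormalize(scaled))
```

Each pixel value should come back unchanged after the round trip, which is why the displayed images look like ordinary MNIST digits.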

Generator and Noise

In [2]:
class Generator(nn.Module):
    '''
    Generator Class
    Values:
        input_dim: the dimension of the input vector, a scalar
        im_chan: the number of channels in the images, fitted for the dataset used, a scalar
              (MNIST is black-and-white, so 1 channel is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, input_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.input_dim = input_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(input_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a generator block of DCGAN;
        a transposed convolution, a batchnorm (except in the final layer), and an activation.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise 
                      (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise):
        '''
        Function for completing a forward pass of the generator: Given a noise tensor, 
        returns generated images.
        Parameters:
            noise: a noise tensor with dimensions (n_samples, input_dim)
        '''
        x = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(x)

def get_noise(n_samples, input_dim, device='cpu'):
    '''
    Function for creating noise vectors: Given the dimensions (n_samples, input_dim)
    creates a tensor of that shape filled with random numbers from the normal distribution.
    Parameters:
        n_samples: the number of samples to generate, a scalar
        input_dim: the dimension of the input vector, a scalar
        device: the device type
    '''
    return torch.randn(n_samples, input_dim, device=device)
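
To see why this stack of blocks ends at a 28 x 28 image, it can help to trace the spatial size through each transposed convolution. The sketch below assumes the `nn.ConvTranspose2d` defaults used above (no padding, no output padding, dilation 1), for which the output size is `(size - 1) * stride + kernel_size`; the helper name `conv_transpose_out` is made up for illustration.

```python
def conv_transpose_out(size, kernel_size, stride):
    # Output size of a transposed convolution with padding=0,
    # output_padding=0 and dilation=1
    return (size - 1) * stride + kernel_size

# (kernel_size, stride) of the four generator blocks, in order
blocks = [(3, 2), (4, 1), (3, 2), (4, 2)]

size = 1  # forward() reshapes the input to (n_samples, input_dim, 1, 1)
sizes = [size]
for kernel_size, stride in blocks:
    size = conv_transpose_out(size, kernel_size, stride)
    sizes.append(size)

print(sizes)  # [1, 3, 6, 13, 28]
```

Note that `input_dim` only affects the number of input channels of the first block, so concatenating the class vector to the noise does not change any of these spatial sizes.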

Discriminator

In [3]:
class Discriminator(nn.Module):
    '''
    Discriminator Class
    Values:
      im_chan: the number of channels in the images, fitted for the dataset used, a scalar
            (MNIST is black-and-white, so 1 channel is your default)
      hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, im_chan=1, hidden_dim=64):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a discriminator block of the DCGAN; 
        a convolution, a batchnorm (except in the final layer), and an activation (except in the final layer).
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise 
                      (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        '''
        Function for completing a forward pass of the discriminator: Given an image tensor, 
        returns a 1-dimension tensor representing fake/real.
        Parameters:
            image: an image tensor with dimensions (n_samples, im_chan, height, width)
        '''
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)
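
The same kind of bookkeeping shows how the discriminator reduces a 28 x 28 input to a single value per image. This sketch assumes the `nn.Conv2d` defaults used above (no padding, dilation 1), for which the output size is `(size - kernel_size) // stride + 1`; the helper name `conv_out` is illustrative.

```python
def conv_out(size, kernel_size=4, stride=2):
    # Output size of a convolution with padding=0 and dilation=1
    return (size - kernel_size) // stride + 1

size = 28  # MNIST images are 28 x 28
sizes = [size]
for _ in range(3):  # the discriminator has three blocks
    size = conv_out(size)
    sizes.append(size)

print(sizes)  # [28, 13, 5, 1]
```

Because the final block has one output channel and a 1 x 1 spatial size, `disc_pred.view(len(disc_pred), -1)` yields one logit per image. Adding class channels to the input changes `im_chan` only, not these spatial sizes.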

Class Input

In conditional GANs, the input vector for the generator also needs to include the class information. The class is represented using a one-hot encoded vector whose length is the number of classes, with each index representing a class. The vector is all 0s except for a 1 at the index of the chosen class. Given the labels of multiple images (e.g. from a batch) and the number of classes, please create one-hot vectors for each label. There is a function within the PyTorch functional library that can help you.

Optional hints for get_one_hot_labels

  1. This code can be done in one line.
  2. The documentation for [F.one_hot](https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.one_hot) may be helpful.
In [4]:
# UNQ_C1 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_one_hot_labels

import torch.nn.functional as F
def get_one_hot_labels(labels, n_classes):
    '''
    Function for creating one-hot vectors for the labels, returns a tensor of shape (?, n_classes).
    Parameters:
        labels: tensor of labels from the dataloader, size (?)
        n_classes: the total number of classes in the dataset, an integer scalar
    '''
    #### START CODE HERE ####
    return F.one_hot(labels, num_classes=n_classes)
    #### END CODE HERE ####
In [5]:
assert (
    get_one_hot_labels(
        labels=torch.Tensor([[0, 2, 1]]).long(),
        n_classes=3
    ).tolist() == 
    [[
      [1, 0, 0], 
      [0, 0, 1], 
      [0, 1, 0]
    ]]
)
print("Success!")
Success!
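
As a small illustration of the encoding described above, the sketch below one-hot encodes three example digit labels (the label values are arbitrary). It also prints the dtype, since `F.one_hot` returns an integer tensor, which is relevant when the result is later combined with floating-point noise.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([3, 0, 9])             # three example digit labels
one_hot = F.one_hot(labels, num_classes=10)  # one row per label

print(one_hot.shape)        # torch.Size([3, 10])
print(one_hot.dtype)        # torch.int64
print(one_hot[0].tolist())  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```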

Next, you need to be able to concatenate the one-hot class vector to the noise vector before giving it to the generator. You will also need to do this when adding the class channels to the discriminator.

To do this, you will need to write a function that combines two vectors. Remember that you need to ensure that the vectors are the same type: floats. Again, you can look to the PyTorch library for help.

Optional hints for combine_vectors

  1. This code can also be written in one line.
  2. The documentation for [torch.cat](https://pytorch.org/docs/master/generated/torch.cat.html) may be helpful.
  3. Specifically, you might want to look at what the `dim` argument of `torch.cat` does.
In [6]:
# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: combine_vectors
def combine_vectors(x, y):
    '''
    Function for combining two vectors with shapes (n_samples, ?) and (n_samples, ?).
    Parameters:
      x: (n_samples, ?) the first vector. 
        In this assignment, this will be the noise vector of shape (n_samples, z_dim), 
        but you shouldn't need to know the second dimension's size.
      y: (n_samples, ?) the second vector.
        Once again, in this assignment this will be the one-hot class vector 
        with the shape (n_samples, n_classes), but you shouldn't assume this in your code.
    '''
    # Note: Make sure this function outputs a float no matter what inputs it receives
    #### START CODE HERE ####
    combined = torch.cat((x.float(), y.float()), dim=1)
    #### END CODE HERE ####
    return combined
In [7]:
combined = combine_vectors(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]]));
# Check exact order of elements
assert torch.all(combined == torch.tensor([[1, 2, 5, 6], [3, 4, 7, 8]]))
# Tests that items are of float type
assert (type(combined[0][0].item()) == float)
# Check shapes
combined = combine_vectors(torch.randn(1, 4, 5), torch.randn(1, 8, 5));
assert tuple(combined.shape) == (1, 12, 5)
assert tuple(combine_vectors(torch.randn(1, 10, 12).long(), torch.randn(1, 20, 12).long()).shape) == (1, 30, 12)
print("Success!")
Success!
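
To make the shapes concrete, here is a minimal sketch of the same concatenation on a tiny batch. The sizes (2 samples, a 4-dimensional noise vector, 3 classes) are chosen only for readability and are not the values used in training.

```python
import torch

n_samples, z_dim = 2, 4
noise = torch.randn(n_samples, z_dim)           # float32, shape (2, 4)
one_hot = torch.tensor([[0, 1, 0], [1, 0, 0]])  # int64, shape (2, 3)

# Cast both to float and join along dim=1 so that each row becomes
# [noise values..., class values...]
combined = torch.cat((noise.float(), one_hot.float()), dim=1)

print(combined.shape)        # torch.Size([2, 7])
print(combined.dtype)        # torch.float32
print(combined[0, z_dim:])   # tensor([0., 1., 0.])
```

Concatenating along `dim=1` is also what lets the same function append class channels to a batch of images, where dimension 1 is the channel dimension.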

Training

Now you can start to put it all together! First, you will define some new parameters:

  • mnist_shape: the shape of each MNIST image, which is 28 x 28 pixels with one channel (because it's black-and-white), so 1 x 28 x 28
  • n_classes: the number of classes in MNIST (10, since there are the digits from 0 to 9)
In [8]:
mnist_shape = (1, 28, 28)
n_classes = 10

And you also include the same parameters from previous assignments:

  • criterion: the loss function
  • n_epochs: the number of times you iterate through the entire dataset when training
  • z_dim: the dimension of the noise vector
  • display_step: how often to display/visualize the images
  • batch_size: the number of images per forward/backward pass
  • lr: the learning rate
  • device: the device type
In [9]:
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.0002
device = 'cuda'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('.', download=False, transform=transform),
    batch_size=batch_size,
    shuffle=True)
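
The criterion above, `nn.BCEWithLogitsLoss`, applies a sigmoid to the raw discriminator output and then computes binary cross-entropy. The sketch below reproduces that arithmetic for a single value in plain Python. It is a naive version for illustration only (PyTorch uses a more numerically stable formulation), and the function name `bce_with_logits` is made up here.

```python
import math

def bce_with_logits(logit, target):
    # Sigmoid turns the raw logit into a probability of "real"
    prob = 1 / (1 + math.exp(-logit))
    # Binary cross-entropy for a single prediction
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))

print(bce_with_logits(0.0, 1.0))  # about 0.693 (ln 2): an undecided prediction
print(bce_with_logits(4.0, 1.0))  # small loss: confident and correct
print(bce_with_logits(4.0, 0.0))  # large loss: confident and wrong
```

A discriminator that cannot tell real from fake (logit 0 for every input) incurs a loss of ln 2, roughly 0.693, which is a useful reference point when reading the loss values printed during training.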

Then, you can initialize your generator, discriminator, and optimizers. To do this, you will need to update the input dimensions for both models. For the generator, you will need to calculate the size of the input vector; recall that for conditional GANs, the generator's input is the noise vector concatenated with the class vector. For the discriminator, you need to add a channel for every class.
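
As a sketch of what "a channel for every class" means for the discriminator input: each one-hot entry is expanded into a full 28 x 28 map, and those maps are stacked onto the image along the channel dimension. The batch below uses two blank placeholder images and arbitrary labels, purely to show the shapes.

```python
import torch
import torch.nn.functional as F

n_classes = 10
images = torch.zeros(2, 1, 28, 28)  # placeholder batch of two blank images
labels = torch.tensor([7, 2])       # arbitrary example labels

one_hot = F.one_hot(labels, num_classes=n_classes)            # (2, 10)
label_maps = one_hot[:, :, None, None].repeat(1, 1, 28, 28)   # (2, 10, 28, 28)

# Channel 0 is the image; channels 1..10 are the class maps
disc_input = torch.cat((images.float(), label_maps.float()), dim=1)

print(disc_input.shape)                     # torch.Size([2, 11, 28, 28])
print(disc_input[0, 1 + 7].unique())        # tensor([1.])
```

For the first image (label 7), only the map at channel index 1 + 7 is filled with ones; every other class channel is all zeros. The generator side is simpler: its input just grows from `z_dim` to `z_dim + n_classes` values.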

In [10]:
# UNQ_C3 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_input_dimensions
def get_input_dimensions(z_dim, mnist_shape, n_classes):
    '''
    Function for getting the size of the conditional input dimensions 
    from z_dim, the image shape, and number of classes.
    Parameters:
        z_dim: the dimension of the noise vector, a scalar
        mnist_shape: the shape of each MNIST image as (C, H, W), which is (1, 28, 28)
        n_classes: the total number of classes in the dataset, an integer scalar
                (10 for MNIST)
    Returns: 
        generator_input_dim: the input dimensionality of the conditional generator, 
                          which takes the noise and class vectors
        discriminator_im_chan: the number of input channels to the discriminator
                            (e.g. C x 28 x 28 for MNIST)
    '''
    #### START CODE HERE ####
    generator_input_dim = z_dim + n_classes
    discriminator_im_chan = mnist_shape[0] + n_classes
    #### END CODE HERE ####
    return generator_input_dim, discriminator_im_chan
In [11]:
def test_input_dims():
    gen_dim, disc_dim = get_input_dimensions(23, (12, 23, 52), 9)
    assert gen_dim == 32
    assert disc_dim == 21
test_input_dims()
print("Success!")
Success!
In [12]:
generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, mnist_shape, n_classes)

gen = Generator(input_dim=generator_input_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator(im_chan=discriminator_im_chan).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

To train, both your generator and your discriminator need to know what class of image should be generated. There are a few locations where you will need to implement code.

For example, if you're generating a picture of the number "1", you would need to:

  1. Tell that to the generator, so that it knows it should be generating a "1"
  2. Tell that to the discriminator, so that it knows it should be looking at a "1". If the discriminator is told it should be looking at a "1" but sees something that's clearly an "8", it can guess that the image is probably fake.

There are no explicit unit tests here -- if this block of code runs and you don't change any of the other variables, then you've done it correctly!

In [13]:
# UNQ_C4 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED CELL
cur_step = 0
generator_losses = []
discriminator_losses = []

#UNIT TEST NOTE: Initializations needed for grading
noise_and_labels = False
fake = False

fake_image_and_labels = False
real_image_and_labels = False
disc_fake_pred = False
disc_real_pred = False

for epoch in range(n_epochs):
    # Dataloader returns the batches and the labels
    for real, labels in tqdm(dataloader):
        cur_batch_size = len(real)
        # Move the batch of real images to the device
        real = real.to(device)

        one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])

        ### Update discriminator ###
        # Zero out the discriminator gradients
        disc_opt.zero_grad()
        # Get noise corresponding to the current batch_size 
        fake_noise = get_noise(cur_batch_size, z_dim, device=device)
        
        # Now you can get the images from the generator
        # Steps: 1) Combine the noise vectors and the one-hot labels for the generator
        #        2) Generate the conditioned fake images
       
        #### START CODE HERE ####
        noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
        fake = gen(noise_and_labels)
        #### END CODE HERE ####
        
        # Make sure that enough images were generated
        assert len(fake) == len(real)
        # Check that correct tensors were combined
        assert tuple(noise_and_labels.shape) == (cur_batch_size, fake_noise.shape[1] + one_hot_labels.shape[1])
        # It comes from the correct generator
        assert tuple(fake.shape) == (len(real), 1, 28, 28)

        # Now you can get the predictions from the discriminator
        # Steps: 1) Create the input for the discriminator
        #           a) Combine the fake images with image_one_hot_labels, 
        #              remember to detach the generator (.detach()) so you do not backpropagate through it
        #           b) Combine the real images with image_one_hot_labels
        #        2) Get the discriminator's prediction on the fakes as disc_fake_pred
        #        3) Get the discriminator's prediction on the reals as disc_real_pred
        
        #### START CODE HERE ####
        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
        real_image_and_labels = combine_vectors(real, image_one_hot_labels)
        disc_fake_pred = disc(fake_image_and_labels.detach())
        disc_real_pred = disc(real_image_and_labels)
        #### END CODE HERE ####
        
        # Make sure shapes are correct 
        assert tuple(fake_image_and_labels.shape) == (len(real), fake.detach().shape[1] + image_one_hot_labels.shape[1], 28 ,28)
        assert tuple(real_image_and_labels.shape) == (len(real), real.shape[1] + image_one_hot_labels.shape[1], 28 ,28)
        # Make sure that enough predictions were made
        assert len(disc_real_pred) == len(real)
        # Make sure that the inputs are different
        assert torch.any(fake_image_and_labels != real_image_and_labels)
        # Shapes must match
        assert tuple(fake_image_and_labels.shape) == tuple(real_image_and_labels.shape)
        assert tuple(disc_fake_pred.shape) == tuple(disc_real_pred.shape)
        
        
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        disc_loss.backward(retain_graph=True)
        disc_opt.step() 

        # Keep track of the discriminator losses
        discriminator_losses += [disc_loss.item()]

        ### Update generator ###
        # Zero out the generator gradients
        gen_opt.zero_grad()

        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
        # This will error if you didn't concatenate your labels to your image correctly
        disc_fake_pred = disc(fake_image_and_labels)
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the generator losses
        generator_losses += [gen_loss.item()]

        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            disc_mean = sum(discriminator_losses[-display_step:]) / display_step
            print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            step_bins = 20
            x_axis = sorted([i * step_bins for i in range(len(generator_losses) // step_bins)] * step_bins)
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins), 
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins), 
                torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Discriminator Loss"
            )
            plt.legend()
            plt.show()
        elif cur_step == 0:
            print("Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!")
        cur_step += 1
Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!

Step 500: Generator loss: 2.390482585787773, discriminator loss: 0.24346981071308255

Step 1000: Generator loss: 4.2990237874984745, discriminator loss: 0.039926913648843765

Step 1500: Generator loss: 4.01603684091568, discriminator loss: 0.07471435723081231

Step 2000: Generator loss: 3.448892023563385, discriminator loss: 0.11762107384204865

Step 2500: Generator loss: 2.768935915470123, discriminator loss: 0.19791465669870376

Step 3000: Generator loss: 2.286963224887848, discriminator loss: 0.2935032919794321

Step 3500: Generator loss: 1.9609758868217468, discriminator loss: 0.3266635043025017

Step 4000: Generator loss: 1.9790613369941712, discriminator loss: 0.34217212396860125

Step 4500: Generator loss: 1.9155494804382325, discriminator loss: 0.37460210391879084

Step 5000: Generator loss: 1.8606488909721375, discriminator loss: 0.38525806948542596

Step 5500: Generator loss: 1.7730398552417754, discriminator loss: 0.39938051271438596

Step 6000: Generator loss: 1.6318049784898758, discriminator loss: 0.42749992617964744

Step 6500: Generator loss: 1.5462016422748566, discriminator loss: 0.4698652623295784

Step 7000: Generator loss: 1.522529669046402, discriminator loss: 0.4905006697773933

Step 7500: Generator loss: 1.4747677425146102, discriminator loss: 0.48235233628749846


Step 8000: Generator loss: 1.3633231954574585, discriminator loss: 0.48055326330661774

Step 8500: Generator loss: 1.3579840159416199, discriminator loss: 0.4856787996292114

Step 9000: Generator loss: 1.4126235563755036, discriminator loss: 0.4979951973557472

Step 9500: Generator loss: 1.310292180299759, discriminator loss: 0.5060559151768684

Step 10000: Generator loss: 1.2593490180969238, discriminator loss: 0.5135202271342277

Step 10500: Generator loss: 1.3024956947565078, discriminator loss: 0.5259441373944282

Step 11000: Generator loss: 1.2247772790193558, discriminator loss: 0.5327630413770675

Step 11500: Generator loss: 1.162479483127594, discriminator loss: 0.5384246173501015

Step 12000: Generator loss: 1.237057696223259, discriminator loss: 0.5497304468154908

Step 12500: Generator loss: 1.2177826163768768, discriminator loss: 0.5453137177228927

Step 13000: Generator loss: 1.1267481179237366, discriminator loss: 0.5614456434249878

Step 13500: Generator loss: 1.1289535319805146, discriminator loss: 0.5586145658493042

Step 14000: Generator loss: 1.1069848237037658, discriminator loss: 0.5657509679794311

Step 14500: Generator loss: 1.0906270860433578, discriminator loss: 0.5702241674661637

Step 15000: Generator loss: 1.1034155098199845, discriminator loss: 0.5782030074596405


Step 15500: Generator loss: 1.0870768781900406, discriminator loss: 0.5779302825331688

Step 16000: Generator loss: 1.079880796432495, discriminator loss: 0.5773033278584481

Step 16500: Generator loss: 1.0585279194116592, discriminator loss: 0.5777821787595749

Step 17000: Generator loss: 1.054906175017357, discriminator loss: 0.5882963989973068

Step 17500: Generator loss: 1.0711934096813203, discriminator loss: 0.5761874537467957

Step 18000: Generator loss: 1.0662164595127106, discriminator loss: 0.577499737560749

Step 18500: Generator loss: 1.047797438621521, discriminator loss: 0.5853800494670868

Step 19000: Generator loss: 1.053154159426689, discriminator loss: 0.5831457045674324

Step 19500: Generator loss: 1.0087834986448287, discriminator loss: 0.5842429170608521

Step 20000: Generator loss: 1.050440570116043, discriminator loss: 0.5836795452833176

Step 20500: Generator loss: 1.0245925661325455, discriminator loss: 0.5867806885242463

Step 21000: Generator loss: 1.0281104496717453, discriminator loss: 0.5866698114871979

Step 21500: Generator loss: 1.0010785210132598, discriminator loss: 0.5891426213979721

Step 22000: Generator loss: 1.0456896343231201, discriminator loss: 0.5827461982369423

Step 22500: Generator loss: 1.0515592069625854, discriminator loss: 0.5963578228354454


Step 23000: Generator loss: 1.0113444550037385, discriminator loss: 0.5978357820510865

Step 23500: Generator loss: 0.9930306963920593, discriminator loss: 0.5956108877658844

Step 24000: Generator loss: 0.9821271594762803, discriminator loss: 0.5940639829039573

Step 24500: Generator loss: 0.9771805860996247, discriminator loss: 0.5974257956147194

Step 25000: Generator loss: 0.9956368246078491, discriminator loss: 0.6001199697852134

Step 25500: Generator loss: 0.981632101893425, discriminator loss: 0.5998036608099937

Step 26000: Generator loss: 0.9542368501424789, discriminator loss: 0.6049897282719612

Step 26500: Generator loss: 0.9782687791585922, discriminator loss: 0.6032342752218246

Step 27000: Generator loss: 0.9727148976325989, discriminator loss: 0.6082928144931793

Step 27500: Generator loss: 0.9843514202833176, discriminator loss: 0.6087211855649948

Step 28000: Generator loss: 0.9833866637945176, discriminator loss: 0.6081592819690704

Step 28500: Generator loss: 0.9509779886007309, discriminator loss: 0.6065650040507317

Step 29000: Generator loss: 0.9555435638427734, discriminator loss: 0.6115117101669312

Step 29500: Generator loss: 0.9797939239740372, discriminator loss: 0.6165106629133225

Step 30000: Generator loss: 0.9485270463228226, discriminator loss: 0.6146931504011154


Step 30500: Generator loss: 0.9519381281137467, discriminator loss: 0.6107044652700424

Step 31000: Generator loss: 0.9687439986467361, discriminator loss: 0.6125063039660453

Step 31500: Generator loss: 0.9376774097681045, discriminator loss: 0.6181296321153641

Step 32000: Generator loss: 0.93297694003582, discriminator loss: 0.6140522452592849

Step 32500: Generator loss: 0.9572884242534637, discriminator loss: 0.6195020067095757

Step 33000: Generator loss: 0.972777440071106, discriminator loss: 0.604750384092331

Step 33500: Generator loss: 0.9443573865890503, discriminator loss: 0.6129889862537384

Step 34000: Generator loss: 0.9689498499631882, discriminator loss: 0.6136127901673317

Step 34500: Generator loss: 0.9415038893222809, discriminator loss: 0.6130608377456666

Step 35000: Generator loss: 0.9240719108581543, discriminator loss: 0.61926220703125

Step 35500: Generator loss: 0.954082661986351, discriminator loss: 0.6173380060195923

Step 36000: Generator loss: 0.9727428333759308, discriminator loss: 0.6149510593414307

Step 36500: Generator loss: 0.9376445689201355, discriminator loss: 0.6174199160933495

Step 37000: Generator loss: 0.9389909583330155, discriminator loss: 0.6160642336010933

Step 37500: Generator loss: 0.9267619924545288, discriminator loss: 0.6168647689819335


Step 38000: Generator loss: 0.9334467047452927, discriminator loss: 0.6189262148141861

Step 38500: Generator loss: 0.9365294445753097, discriminator loss: 0.6181443547010421

Step 39000: Generator loss: 0.9170657801628113, discriminator loss: 0.6185943145155907

Step 39500: Generator loss: 0.9512229292392731, discriminator loss: 0.6173720422387123

Step 40000: Generator loss: 0.9405387555360794, discriminator loss: 0.6148940899372101

Step 40500: Generator loss: 0.9187011179924012, discriminator loss: 0.6186534762382507

Step 41000: Generator loss: 0.9478154428005219, discriminator loss: 0.6126384736299515

Step 41500: Generator loss: 0.9425623420476913, discriminator loss: 0.6214258141517639

Step 42000: Generator loss: 0.9260474470853806, discriminator loss: 0.6143858760595322

Step 42500: Generator loss: 0.9308006254434585, discriminator loss: 0.6138124242424965

Step 43000: Generator loss: 0.9477250524759293, discriminator loss: 0.6169488828778267

Step 43500: Generator loss: 0.9339873970746994, discriminator loss: 0.618168544948101

Step 44000: Generator loss: 0.9214147770404816, discriminator loss: 0.6160300189256668

Step 44500: Generator loss: 0.9398437255620956, discriminator loss: 0.6139895544648171

Step 45000: Generator loss: 0.9487018301486969, discriminator loss: 0.6198658150434494


Step 45500: Generator loss: 0.9378825527429581, discriminator loss: 0.6162218267321586

Step 46000: Generator loss: 0.9279566022157669, discriminator loss: 0.6161045470237732

Step 46500: Generator loss: 0.9248676340579987, discriminator loss: 0.618170117020607

Step 47000: Generator loss: 0.9279493466615677, discriminator loss: 0.6205574822425842

Step 47500: Generator loss: 0.946274598121643, discriminator loss: 0.6160013948678971

Step 48000: Generator loss: 0.944700514793396, discriminator loss: 0.6131612827181816

Step 48500: Generator loss: 0.9284282505512238, discriminator loss: 0.621739746928215

Step 49000: Generator loss: 0.9196328492164612, discriminator loss: 0.6140316212177277

Step 49500: Generator loss: 0.93618474817276, discriminator loss: 0.6146838818788528

Step 50000: Generator loss: 0.9389032748937607, discriminator loss: 0.6210263795256614

Step 50500: Generator loss: 0.9446513398885726, discriminator loss: 0.6144609298706055

Step 51000: Generator loss: 0.9266552208662033, discriminator loss: 0.61010345184803

Step 51500: Generator loss: 0.9457362469434738, discriminator loss: 0.6112923196554184

Step 52000: Generator loss: 0.9396196559667588, discriminator loss: 0.6137063590884209

Step 52500: Generator loss: 0.940164351940155, discriminator loss: 0.6108480334281922


Step 53000: Generator loss: 0.9524199627637863, discriminator loss: 0.6062074877619743

Step 53500: Generator loss: 0.9559306852817535, discriminator loss: 0.6107606259584427

Step 54000: Generator loss: 0.9457339569330215, discriminator loss: 0.6127307369709015

Step 54500: Generator loss: 0.96673697412014, discriminator loss: 0.6048597301244736

Step 55000: Generator loss: 0.9315944395065308, discriminator loss: 0.6112928040027619

Step 55500: Generator loss: 0.9514978064298629, discriminator loss: 0.6066118884086609

Step 56000: Generator loss: 0.9312148452997208, discriminator loss: 0.6096405319571495
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-13-b993e315ebb8> in <module>
     58         fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
     59         real_image_and_labels = combine_vectors(real, image_one_hot_labels)
---> 60         disc_fake_pred = disc(fake_image_and_labels.detach())
     61         disc_real_pred = disc(real_image_and_labels)
     62         #### END CODE HERE ####

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

<ipython-input-3-b53366420424> in forward(self, image)
     45             image: a flattened image tensor with dimension (im_chan)
     46         '''
---> 47         disc_pred = self.disc(image)
     48         return disc_pred.view(len(disc_pred), -1)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py in forward(self, input)
     98     def forward(self, input):
     99         for module in self:
--> 100             input = module(input)
    101         return input
    102 

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py in forward(self, input)
     98     def forward(self, input):
     99         for module in self:
--> 100             input = module(input)
    101         return input
    102 

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/batchnorm.py in forward(self, input)
     96             # TODO: if statement only here to tell the jit to skip emitting this when it is None
     97             if self.num_batches_tracked is not None:
---> 98                 self.num_batches_tracked = self.num_batches_tracked + 1
     99                 if self.momentum is None:  # use cumulative moving average
    100                     exponential_average_factor = 1.0 / float(self.num_batches_tracked)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __setattr__(self, name, value)
    582                     del d[name]
    583 
--> 584         params = self.__dict__.get('_parameters')
    585         if isinstance(value, Parameter):
    586             if params is None:

KeyboardInterrupt: 

Exploration

You can do a bit of exploration now!

In [14]:
# Before you explore, put the generator in eval mode so that
# batch norm uses its stored running statistics instead of
# the statistics of whatever batch you happen to pass in
gen = gen.eval()

Changing the Class Vector

You can generate some numbers with your new model! You can add interpolation as well to make it more interesting.

Starting from one image, you will produce intermediate images that look more and more like the ending image until you reach the final image. You're basically morphing one image into another. You can choose what these two images will be using your conditional GAN.
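
To make the arithmetic concrete, here is a minimal sketch of the label interpolation in plain Python, with no PyTorch needed. The helpers `one_hot` and `interpolate_labels` are hypothetical names used only for illustration; they are not part of the notebook. The sketch assumes 10 classes and evenly spaced weights, which mirrors what `torch.linspace(0, 1, n_interpolation)` produces in the cell that follows.

```python
# Plain-Python sketch of blending two one-hot class labels.
# one_hot and interpolate_labels are hypothetical helpers for illustration.

def one_hot(digit, n_classes=10):
    """Return a one-hot list with 1.0 at index `digit` and 0.0 elsewhere."""
    return [1.0 if i == digit else 0.0 for i in range(n_classes)]

def interpolate_labels(first_digit, second_digit, n_steps, n_classes=10):
    """Linearly blend two one-hot labels over n_steps evenly spaced weights.

    Assumes n_steps >= 2 so that both endpoints are included.
    """
    first = one_hot(first_digit, n_classes)
    second = one_hot(second_digit, n_classes)
    rows = []
    for step in range(n_steps):
        w = step / (n_steps - 1)  # weight on the second label, from 0 to 1
        rows.append([(1 - w) * a + w * b for a, b in zip(first, second)])
    return rows

labels = interpolate_labels(1, 5, n_steps=5)
for row in labels:
    print([round(v, 2) for v in row])
```

Every row still sums to 1, but only the first and last rows are true one-hot vectors. The generator only ever saw one-hot labels during training, so the in-between images come from inputs it was never trained on, which is part of what makes them interesting to look at.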

In [15]:
import math

### Change me! ###
n_interpolation = 9 # Choose the interpolation: how many intermediate images you want + 2 (for the start and end image)
interpolation_noise = get_noise(1, z_dim, device=device).repeat(n_interpolation, 1)

def interpolate_class(first_number, second_number):
    first_label = get_one_hot_labels(torch.tensor([first_number], dtype=torch.long), n_classes)
    second_label = get_one_hot_labels(torch.tensor([second_number], dtype=torch.long), n_classes)

    # Calculate the interpolation vector between the two labels
    percent_second_label = torch.linspace(0, 1, n_interpolation)[:, None]
    interpolation_labels = first_label * (1 - percent_second_label) + second_label * percent_second_label

    # Combine the noise and the labels
    noise_and_labels = combine_vectors(interpolation_noise, interpolation_labels.to(device))
    fake = gen(noise_and_labels)
    show_tensor_images(fake, num_images=n_interpolation, nrow=int(math.sqrt(n_interpolation)), show=False)

### Change me! ###
start_plot_number = 1 # Choose the start digit
### Change me! ###
end_plot_number = 5 # Choose the end digit

plt.figure(figsize=(8, 8))
interpolate_class(start_plot_number, end_plot_number)
_ = plt.axis('off')

### Uncomment the following lines of code if you would like to visualize a set of pairwise class 
### interpolations for a collection of different numbers, all in a single grid of interpolations.
### You'll also see another visualization like this in the next code block!
# plot_numbers = [2, 3, 4, 5, 7]
# n_numbers = len(plot_numbers)
# plt.figure(figsize=(8, 8))
# for i, first_plot_number in enumerate(plot_numbers):
#     for j, second_plot_number in enumerate(plot_numbers):
#         plt.subplot(n_numbers, n_numbers, i * n_numbers + j + 1)
#         interpolate_class(first_plot_number, second_plot_number)
#         plt.axis('off')
# plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
# plt.show()
# plt.close()

Changing the Noise Vector

Now, what happens if you hold the class constant, but instead you change the noise vector? You can also interpolate the noise vector and generate an image at each step.
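
As a rough illustration of the same idea for noise, here is a plain-Python sketch that blends two noise vectors. The function `interpolate_noise_sketch` and the tiny `z_dim = 4` are assumptions chosen for readability; they are not the notebook's code. The weights are paired the same way as in the cell that follows: the weight `w` multiplies the first noise vector and `1 - w` multiplies the second.

```python
import random

# Plain-Python sketch of blending two noise vectors.
# interpolate_noise_sketch is a hypothetical helper for illustration.

def interpolate_noise_sketch(first_noise, second_noise, n_steps):
    """Linearly blend two equal-length vectors over n_steps weights.

    Assumes n_steps >= 2. The weight w goes from 0 to 1 and is applied
    to first_noise, so the path starts at second_noise and ends at first_noise.
    """
    rows = []
    for step in range(n_steps):
        w = step / (n_steps - 1)  # weight on first_noise
        rows.append([w * a + (1 - w) * b for a, b in zip(first_noise, second_noise)])
    return rows

random.seed(0)
z_dim = 4  # tiny dimension for readability; the notebook's z_dim is larger
first = [random.gauss(0, 1) for _ in range(z_dim)]
second = [random.gauss(0, 1) for _ in range(z_dim)]

path = interpolate_noise_sketch(first, second, n_steps=9)
print("start:", [round(v, 3) for v in path[0]])
print("end:  ", [round(v, 3) for v in path[-1]])
```

With this pairing, the first image in each interpolation comes from the second noise vector and the last image comes from the first one. Because the class label is held fixed, every step should still look like the same digit while style features such as stroke thickness or slant change.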

In [16]:
n_interpolation = 9 # How many intermediate images you want + 2 (for the start and end image)

# Hold the class label fixed (here, the digit 5) and repeat it for every interpolation step
interpolation_label = get_one_hot_labels(torch.Tensor([5]).long(), n_classes).repeat(n_interpolation, 1).float()

def interpolate_noise(first_noise, second_noise):
    # This time you're interpolating between the noise instead of the labels
    percent_first_noise = torch.linspace(0, 1, n_interpolation)[:, None].to(device)
    interpolation_noise = first_noise * percent_first_noise + second_noise * (1 - percent_first_noise)

    # Combine the noise and the labels again
    noise_and_labels = combine_vectors(interpolation_noise, interpolation_label.to(device))
    fake = gen(noise_and_labels)
    show_tensor_images(fake, num_images=n_interpolation, nrow=int(math.sqrt(n_interpolation)), show=False)

# Generate noise vectors to interpolate between
### Change me! ###
n_noise = 5 # Choose the number of noise examples in the grid
plot_noises = [get_noise(1, z_dim, device=device) for i in range(n_noise)]
plt.figure(figsize=(8, 8))
for i, first_plot_noise in enumerate(plot_noises):
    for j, second_plot_noise in enumerate(plot_noises):
        plt.subplot(n_noise, n_noise, i * n_noise + j + 1)
        interpolate_noise(first_plot_noise, second_plot_noise)
        plt.axis('off')
plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
plt.show()
plt.close()